# =============================================================================
# KNN (k=4) SDM Panel (Robustness) - TRUE SDM + EXACT EFFECTS
# =============================================================================
import pandas as pd
import numpy as np
import geopandas as gpd
from libpysal.weights import KNN
from spreg import Panel_FE_Lag
from scipy import stats
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from esda.moran import Moran
from datetime import datetime
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

print("=== KNN (k=4) SDM PANEL (Robustness - TRUE SDM) ===")

# --------------------------------------------------
# 1. Load the dataset
# --------------------------------------------------
# Input table; later sections read its `year`, `province_name`, `ln_wage`
# and regressor columns.
df = pd.read_csv("dataset.csv")
n_obs = len(df)
print(f"Loaded dataset: {n_obs:,} observations")

# --------------------------------------------------
# 2. Define variables
# --------------------------------------------------
# Regressors of the model. Each one later receives a spatial lag (W_<name>)
# and a year-demeaned copy (<name>_demean_year).
exog_vars = [
    'spe',
    'div',
    'den',
    'ln_fdi',
    'ln_productivity_lag',
    'skilled_share',
]

# --------------------------------------------------
# 3. Prepare spatial weights (KNN k=4)
# --------------------------------------------------
# Province geometries. Names on both sides are lower-cased and trimmed so the
# panel rows and the polygons can be matched on the same key.
gdf = gpd.read_file("province_vn.geojson")
gdf['norm'] = gdf['ten_tinh'].str.lower().str.strip()
df['prov_norm'] = df['province_name'].str.lower().str.strip()

# Keep only the provinces present in both the panel and the geometry file.
common = set(df['prov_norm']) & set(gdf['norm'])
panel_matched = df['prov_norm'].isin(common)
geom_matched = gdf['norm'].isin(common)
df = df.loc[panel_matched].copy()
# Index is reset so that row position in gdf can serve as the province id.
gdf = gdf.loc[geom_matched].copy().reset_index(drop=True)

# 4-nearest-neighbour weights on the matched geometries, with the 'r'
# (row-standardised) transform applied.
w = KNN.from_dataframe(gdf, k=4, use_index=False)
w.transform = 'r'
print(f"KNN (k=4) weights created: {w.n} provinces, avg neighbors = {w.mean_neighbors:.2f}")

# Province key -> row position in the filtered gdf, attached to every panel row.
prov_to_id = dict(zip(gdf['norm'], range(len(gdf))))
df['prov_id'] = df['prov_norm'].map(prov_to_id)

# --------------------------------------------------
# 4. Create WX variables (province-year averaged)
# --------------------------------------------------
# Dense weights matrix; w.full() returns (matrix, ids) and only the matrix is used.
W_matrix = w.full()[0]
wx_array = np.zeros((len(df), len(exog_vars)))

# Positional copies of the row keys. Results are written back through boolean
# row masks, so each lagged value lands on the row it was computed for no
# matter how df is sorted or indexed (df need not be stored in year blocks).
year_of_row = df['year'].to_numpy()
prov_of_row = df['prov_id'].to_numpy()

for i, var in enumerate(exog_vars):
    # Mean of `var` within each (year, province) cell, repeated on every row
    # of that cell.
    prov_year_avg = df.groupby(['year', 'prov_id'])[var].transform('mean').to_numpy()
    for year in pd.unique(year_of_row):
        in_year = year_of_row == year
        ids = prov_of_row[in_year]
        # One value per province for this year; provinces with no rows in
        # this year keep the value 0 and enter their neighbours' lag as 0.
        province_vec = np.zeros(w.n)
        province_vec[ids] = prov_year_avg[in_year]
        lagged = W_matrix @ province_vec
        # Every row of the year receives the lag of its own province.
        wx_array[in_year, i] = lagged[ids]

for i, var in enumerate(exog_vars):
    df[f'W_{var}'] = wx_array[:, i]
print("WX variables created.")

# --------------------------------------------------
# 5. Year demeaning (time FE)
# --------------------------------------------------
# Subtract the cross-sectional mean of each year from the dependent variable
# and from every regressor; results go to <name>_demean_year columns.
demean_cols = ['ln_wage'] + exog_vars
year_means = df.groupby('year')[demean_cols].transform('mean')
for name in demean_cols:
    df[f'{name}_demean_year'] = df[name] - year_means[name]
print("Year demeaning applied.")

# --------------------------------------------------
# 6. Joint orthogonalization of WX (VIF fix)
# --------------------------------------------------
print("Joint orthogonalizing WX lags...")
# Design matrix for the auxiliary regressions: a constant plus all
# year-demeaned regressors.
demeaned_names = [f'{v}_demean_year' for v in exog_vars]
X_full = df[demeaned_names]
X_full_with_const = sm.add_constant(X_full)
for var in exog_vars:
    orth_name = f'W_{var}_joint_orth'
    # Regress the spatial lag on all X jointly and keep the OLS residual.
    model = sm.OLS(df[f'W_{var}'], X_full_with_const).fit()
    df[orth_name] = model.resid
    # Report the largest absolute correlation left between the residual and
    # any of the X columns.
    abs_corrs = [abs(df[orth_name].corr(df[x_name])) for x_name in X_full.columns]
    max_corr = max(abs_corrs)
    print(f"Joint orth {var}: max corr with X = {max_corr:.4f}")
print("Orthogonalization complete.")

# --------------------------------------------------
# 7. Run TRUE SDM
# --------------------------------------------------
# Dependent variable as a column vector.
y = df['ln_wage_demean_year'].to_numpy().reshape(-1, 1)

# Regressor order matters downstream: the year-demeaned X columns come first,
# followed by the orthogonalized spatial lags in the same variable order.
own_terms = [f'{v}_demean_year' for v in exog_vars]
lag_terms = [f'W_{v}_joint_orth' for v in exog_vars]
x_vars_sdm = own_terms + lag_terms
X = df[x_vars_sdm].to_numpy()

# NOTE(review): y and X are passed in df row order with one row per df record;
# confirm this matches the stacking and panel dimensions Panel_FE_Lag expects
# for a weights object with w.n units.
sdm = Panel_FE_Lag(
    y=y,
    x=X,
    w=w,
    name_y='ln_wage',
    name_x=x_vars_sdm,
    name_ds='Vietnam Province-Sector-Year 2018–2022 (KNN k=4 SDM)',
)

rule = "=" * 100
print("\n" + rule)
print("KNN (k=4) SDM PANEL - TRUE SDM")
print(rule)
print(sdm.summary)

# --------------------------------------------------
# 8. Diagnostics (unchanged)
# --------------------------------------------------
# --- Likelihood-ratio test: model with W_* terms vs. the same model without ---
# The restricted model drops the len(exog_vars) lag regressors, which is the
# degrees of freedom used for the chi-square p-value.
# NOTE(review): this comparison tests whether the lag coefficients are jointly
# zero; confirm that the "common factor" label in the message is the intended
# name for it.
x_sar = df[[f'{v}_demean_year' for v in exog_vars]].values
sar = Panel_FE_Lag(y=y, x=x_sar, w=w)
lr_stat = 2 * (sdm.logll - sar.logll)
p_lr = stats.chi2.sf(lr_stat, len(exog_vars))
lr_verdict = 'reject' if p_lr < 0.05 else 'fail to reject'
print(f"\nLR common factor test: stat = {lr_stat:.3f}, p = {p_lr:.4f} → {lr_verdict}")

# --- Moran's I on residuals, one test per year ---
df['sdm_resid'] = sdm.u.flatten()
print("\nMoran's I by year (province-averaged residuals):")
moran_res_sdm = []
# Significance markers, checked from the strictest threshold down.
star_cutoffs = ((0.01, "***"), (0.05, "**"), (0.10, "*"))
for year in sorted(df['year'].unique()):
    # Average the residuals within each province for this year.
    prov_resid = df.loc[df['year'] == year].groupby('prov_id')['sdm_resid'].mean()
    # Provinces without rows in this year keep a residual of 0.
    resid_vec = np.zeros(w.n)
    resid_vec[prov_resid.index.astype(int)] = prov_resid.values
    mi = Moran(resid_vec, w, permutations=999)
    moran_res_sdm.append((year, mi.I, mi.p_sim))
    sig = next((mark for cutoff, mark in star_cutoffs if mi.p_sim < cutoff), "")
    print(f"Year {year}: I = {mi.I:.4f}{sig} (p={mi.p_sim:.4f})")

# --- Variance inflation factors for the constant and every SDM regressor ---
print("\nVIF CHECK")
X_vif = sm.add_constant(df[x_vars_sdm])
vif_matrix = X_vif.values
vif_scores = [variance_inflation_factor(vif_matrix, j) for j in range(vif_matrix.shape[1])]
vif_data = pd.DataFrame({'feature': ['const'] + x_vars_sdm, 'VIF': vif_scores})
print(vif_data.round(2))

# --------------------------------------------------
# 9. EXACT LeSage & Pace (2009) EFFECTS
# --------------------------------------------------
print("\n" + "="*100)
print("LeSAGE & PACE (2009) EFFECTS – EXACT (full matrix) - KNN (k=4)")
print("="*100)

# Spatial autoregressive coefficient as a plain float; np.ravel makes this
# work whether sdm.rho is a scalar or a size-1 array.
rho = float(np.ravel(sdm.rho)[0])
n = len(exog_vars)
# Coefficient layout follows x_vars_sdm: the first n entries belong to the
# year-demeaned X columns, the next n to the W_*_joint_orth columns.
betas = sdm.betas.flatten()[:n]      # X coefficients
thetas = sdm.betas.flatten()[n:2*n]  # WX_orth coefficients
# NOTE(review): the lag regressors are OLS residuals of WX on X (section 6),
# so `betas` are coefficients conditional on residualized lags. Confirm
# whether they should be back-transformed before being used as the
# X coefficients in the effects formula below.

W_mat = w.full()[0]
I = np.eye(w.n)
S_inv = np.linalg.inv(I - rho * W_mat)

# Effects are computed once and reused for both the console table and the
# results file, so the two outputs cannot diverge.
effects_table = []
for i, var in enumerate(exog_vars):
    M = betas[i] * I + thetas[i] * W_mat
    effects = S_inv @ M
    # Average direct effect: mean of the diagonal of the effects matrix.
    direct = np.mean(np.diag(effects))
    # Average total effect: mean of the ROW SUMS (sum of all entries divided
    # by the number of units). Averaging over all n*n entries instead would
    # shrink the total by a factor of n.
    total = np.mean(effects.sum(axis=1))
    # Average indirect effect is the remainder.
    indirect = total - direct
    effects_table.append((var, direct, indirect, total))

print(f"{'Variable':20} {'Direct':>12} {'Indirect':>12} {'Total':>12}")
for var, direct, indirect, total in effects_table:
    print(f"{var:20} {direct:12.4f} {indirect:12.4f} {total:12.4f}")

# --------------------------------------------------
# 10. SAVE RESULTS
# --------------------------------------------------
# Writes the model summary, the effects table, the LR test, the yearly
# Moran's I results and the VIF table to a UTF-8 text file.
output_file = "SDM_KNN_results.txt"
with open(output_file, "w", encoding="utf-8") as f:
    f.write("="*100 + "\n")
    f.write("KNN (k=4) SDM PANEL – TRUE SDM (Exact Effects)\n")
    f.write(f"Date: {datetime.now().strftime('%B %d, %Y at %H:%M')}\n")
    f.write("="*100 + "\n\n")
    f.write(str(sdm.summary) + "\n\n")
    f.write("LeSage & Pace Effects (EXACT):\n")
    for var, direct, indirect, total in effects_table:
        f.write(f"{var:20} Direct: {direct:8.4f}  Indirect: {indirect:8.4f}  Total: {total:8.4f}\n")
    f.write(f"\nLR common factor test: stat = {lr_stat:.3f}, p = {p_lr:.4f}\n")
    f.write("Moran's I on residuals:\n")
    for year, mi_val, p_val in moran_res_sdm:
        f.write(f"Year {year}: I = {mi_val:.4f} (p={p_val:.4f})\n")
    f.write("\nVIF table:\n")
    f.write(vif_data.round(2).to_string())

print(f"\n✅ KNN (k=4) results successfully saved to: {output_file}")